{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# reshape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import onnx\n",
    "from onnx import TensorProto, helper, mapping, numpy_helper\n",
    "import onnxruntime\n",
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "def get_input_data_shape_dict(graph_def, input_data):\n",
    "    \"\"\"Map graph input names to the shapes of the supplied input data.\n",
    "\n",
    "    For a list of inputs the result is ``({position: name}, {name: shape})``.\n",
    "    Entries that are ``None``, have no ``shape`` attribute, or whose shape is\n",
    "    ``()`` are left out of the shape dict. For a non-list input the result is\n",
    "    ``(name, {name: input_data.shape})`` using the first graph input.\n",
    "    \"\"\"\n",
    "    if not isinstance(input_data, list):\n",
    "        # Single input: bind it to the first graph input.\n",
    "        sole_name = graph_def.graph.input[0].name\n",
    "        return sole_name, {sole_name: input_data.shape}\n",
    "\n",
    "    names_by_pos = {}\n",
    "    shapes_by_name = {}\n",
    "    for pos, value in enumerate(input_data):\n",
    "        name = graph_def.graph.input[pos].name\n",
    "        names_by_pos[pos] = name\n",
    "\n",
    "        # No shape entry for None (optional ONNX operator arguments), for\n",
    "        # values without a ``shape`` attribute, or for empty shapes.\n",
    "        if value is None or not hasattr(value, \"shape\") or value.shape == ():\n",
    "            continue\n",
    "\n",
    "        # Plain Python lists have no ``shape`` and never get here; the list\n",
    "        # case only applies to list-like objects that do expose one.\n",
    "        shapes_by_name[name] = (len(value),) if isinstance(value, list) else value.shape\n",
    "\n",
    "    return names_by_pos, shapes_by_name\n",
    "\n",
    "def get_onnxruntime_output(graph_def, inputs):\n",
    "    \"\"\"Run the model with onnxruntime on the CPU provider and return its outputs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    graph_def\n",
    "        Object exposing ``SerializeToString()`` (the ONNX model built in this notebook).\n",
    "    inputs\n",
    "        A single array, or a list/tuple of arrays in session-input order.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        The result of ``InferenceSession.run`` for every session output.\n",
    "    \"\"\"\n",
    "    # A bare array would be iterated along its first axis by ``zip`` below,\n",
    "    # so wrap it and treat it as one input (same convention as ``get_tvm_output``).\n",
    "    if not isinstance(inputs, (list, tuple)):\n",
    "        inputs = [inputs]\n",
    "    sess = onnxruntime.InferenceSession(\n",
    "        graph_def.SerializeToString(), providers=['CPUExecutionProvider']\n",
    "    )\n",
    "    # Build the feed dict locally instead of writing into an undefined name.\n",
    "    feed = {meta.name: data for meta, data in zip(sess.get_inputs(), inputs)}\n",
    "    output_names = [out.name for out in sess.get_outputs()]\n",
    "    return sess.run(output_names, feed)\n",
    "\n",
    "def get_tvm_output(\n",
    "    graph_def,\n",
    "    input_data,\n",
    "    target,\n",
    "    dev,\n",
    "    opset=None,\n",
    "    freeze_params=False,\n",
    "    convert_config=None,\n",
    "):\n",
    "    \"\"\"Import the ONNX model into Relay, run it on the VM executor and return NumPy outputs.\n",
    "\n",
    "    ``input_data`` may be a single array or a list of arrays. The result is\n",
    "    always a list of NumPy arrays, one per model output.\n",
    "    \"\"\"\n",
    "    feeds = input_data if isinstance(input_data, list) else [input_data]\n",
    "    _, shapes = get_input_data_shape_dict(graph_def, feeds)\n",
    "    mod, params = relay.frontend.from_onnx(\n",
    "        graph_def,\n",
    "        shapes,\n",
    "        opset=opset,\n",
    "        freeze_params=freeze_params,\n",
    "        convert_config=convert_config,\n",
    "    )\n",
    "    executor = relay.create_executor(\"vm\", mod=mod, device=dev, target=target)\n",
    "    run_model = executor.evaluate()\n",
    "    outputs = run_model(*feeds, **params)\n",
    "    # A single-output model yields one NDArray; wrap it so callers always get a list.\n",
    "    if not isinstance(outputs, tvm.runtime.NDArray):\n",
    "        return [item.numpy() for item in outputs]\n",
    "    return [outputs.numpy()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "配置："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch directory for generated artifacts (created if missing).\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)\n",
    "# Path where the ONNX model is stored.\n",
    "model_path = \"{}/Reshape.onnx\".format(temp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_shape = (4, 3, 3, 4)\n",
    "ref_shape = (6, 2, 4, 3)\n",
    "\n",
    "# ONNX Reshape requires its ``shape`` input to be tensor(int64); an int32\n",
    "# constant makes onnxruntime reject the model with INVALID_GRAPH.\n",
    "ref_array = np.array(ref_shape, dtype=np.int64)\n",
    "ref_node = helper.make_node(\n",
    "    \"Constant\",\n",
    "    inputs=[],\n",
    "    outputs=[\"ref_in\"],\n",
    "    value=helper.make_tensor(\n",
    "        name=\"const_tensor\",\n",
    "        data_type=TensorProto.INT64,\n",
    "        dims=ref_array.shape,\n",
    "        vals=ref_array.tolist(),\n",
    "    ),\n",
    ")\n",
    "# Reshape takes the data tensor \"in\" and the constant target shape \"ref_in\".\n",
    "reshape_node = helper.make_node(\"Reshape\", [\"in\", \"ref_in\"], [\"out\"])\n",
    "\n",
    "graph = helper.make_graph(\n",
    "    [ref_node, reshape_node],\n",
    "    \"reshape_test\",\n",
    "    inputs=[helper.make_tensor_value_info(\"in\", TensorProto.FLOAT, list(in_shape))],\n",
    "    outputs=[helper.make_tensor_value_info(\"out\", TensorProto.FLOAT, list(ref_shape))],\n",
    ")\n",
    "\n",
    "graph_def = helper.make_model(graph, producer_name=\"reshape_test\")\n",
    "onnx.save(graph_def, model_path)  # persist the model to disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "InvalidGraph",
     "evalue": "[ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int32)' of input parameter (ref_in) of operator (Reshape) in node (Reshape_0) is invalid.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mInvalidGraph\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 9\u001b[0m\n\u001b[1;32m      3\u001b[0m inputs \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39muniform(size\u001b[38;5;241m=\u001b[39min_shape)\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m      4\u001b[0m tvm_result \u001b[38;5;241m=\u001b[39m get_tvm_output(graph_def,\n\u001b[1;32m      5\u001b[0m     inputs,\n\u001b[1;32m      6\u001b[0m     target,\n\u001b[1;32m      7\u001b[0m     dev\n\u001b[1;32m      8\u001b[0m )\n\u001b[0;32m----> 9\u001b[0m ort_out \u001b[38;5;241m=\u001b[39m get_onnxruntime_output(graph_def, inputs)\n",
      "Cell \u001b[0;32mIn[2], line 38\u001b[0m, in \u001b[0;36mget_onnxruntime_output\u001b[0;34m(graph_def, inputs)\u001b[0m\n\u001b[1;32m     36\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Generic function to generate onnxruntime output\"\"\"\u001b[39;00m\n\u001b[1;32m     37\u001b[0m \u001b[38;5;66;03m# rep = onnxruntime.backend.prepare(graph_def.SerializeToString(), 'CPU', providers=['CPUExecutionProvider'])\u001b[39;00m\n\u001b[0;32m---> 38\u001b[0m sess \u001b[38;5;241m=\u001b[39m onnxruntime\u001b[38;5;241m.\u001b[39mInferenceSession(\n\u001b[1;32m     39\u001b[0m     graph_def\u001b[38;5;241m.\u001b[39mSerializeToString(), providers\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCPUExecutionProvider\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m     40\u001b[0m )\n\u001b[1;32m     41\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m x, data \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(sess\u001b[38;5;241m.\u001b[39mget_inputs(), inputs):\n\u001b[1;32m     42\u001b[0m     input_names[x\u001b[38;5;241m.\u001b[39mname] \u001b[38;5;241m=\u001b[39m data\n",
      "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:465\u001b[0m, in \u001b[0;36mInferenceSession.__init__\u001b[0;34m(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)\u001b[0m\n\u001b[1;32m    462\u001b[0m disabled_optimizers \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdisabled_optimizers\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    464\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 465\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_create_inference_session(providers, provider_options, disabled_optimizers)\n\u001b[1;32m    466\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mRuntimeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    467\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_enable_fallback:\n",
      "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:528\u001b[0m, in \u001b[0;36mInferenceSession._create_inference_session\u001b[0;34m(self, providers, provider_options, disabled_optimizers)\u001b[0m\n\u001b[1;32m    526\u001b[0m     sess \u001b[38;5;241m=\u001b[39m C\u001b[38;5;241m.\u001b[39mInferenceSession(session_options, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_path, \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_config_from_model)\n\u001b[1;32m    527\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 528\u001b[0m     sess \u001b[38;5;241m=\u001b[39m C\u001b[38;5;241m.\u001b[39mInferenceSession(session_options, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_bytes, \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_config_from_model)\n\u001b[1;32m    530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m disabled_optimizers \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    531\u001b[0m     disabled_optimizers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "\u001b[0;31mInvalidGraph\u001b[0m: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int32)' of input parameter (ref_in) of operator (Reshape) in node (Reshape_0) is invalid."
     ]
    }
   ],
   "source": [
    "target = \"llvm\"\n",
    "dev = tvm.device(target)\n",
    "\n",
    "# The graph declares \"in\" as FLOAT, so feed float32 data; casting to int32\n",
    "# would also truncate every uniform[0, 1) sample to 0. Seeded so that\n",
    "# re-running the notebook reproduces the same input.\n",
    "rng = np.random.default_rng(0)\n",
    "inputs = rng.uniform(size=in_shape).astype(\"float32\")\n",
    "\n",
    "tvm_result = get_tvm_output(graph_def, inputs, target, dev)\n",
    "ort_out = get_onnxruntime_output(graph_def, inputs)\n",
    "\n",
    "# Both backends must produce the requested shape and identical values.\n",
    "assert tvm_result[0].shape == ref_shape\n",
    "np.testing.assert_allclose(tvm_result[0], ort_out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
